
module cenf

using JLD, JuMP, MAT, Roots
using KNITRO
# using Ipopt
using Gadfly
using MathOptInterface
using CSV
using SparseArrays, LinearAlgebra, Statistics
using Logging
using DelimitedFiles
using SpecialFunctions
using DataFrames
using Random

# Module-level data containers. They are declared here and assigned inside readData().
global isocodes, indices, lhs, z_gm, z_f, δ, δt, countrynames, countrynames_full, sectornames, consumptionShares

# Number of η coefficients; eval_problem! uses it to size the η variable vector
# and to count the cross-country parameters.
global const CONST_DIM_ETA = 2

# NOTE(review): solveCountry, solveWithStartingβη_gm, solveWithStartingβη_f and
# eval_ces_problem! have no active (un-commented) definition in the part of the
# file reviewed here -- confirm they are defined elsewhere or drop them from the list.
export solveCountry, readData, solveWithStartingβη_gm, solveWithStartingβη_f, eval_ces_problem!

# data arrays filled by readData()
export z_gm, z_f, δ, δt

# readData(root): read the data files (2016 version) found under "<root>/data"
# and store them in the module-level globals.
#
# Files read:
#   dataset.csv            -> isocodes, indices, lhs, δ, δt, z_gm, z_f
#   countrynames.mat       -> countrynames
#   countrynames_full.mat  -> countrynames_full
#   sectornames.mat        -> sectornames
#   consumptionshares.csv  -> consumptionShares
#
# The array sizes (109 countries, 35 sectors, ISO codes below 1000) are hard-coded.
# The function is called for its side effects on the globals.
function readData(root::String)

  global isocodes, indices, lhs, z_gm, z_f, δ, δt, countrynames, countrynames_full, sectornames, consumptionShares

  # File format:
  # countrycode upstream downstream fraction delta delta_withtime n_div_gm n_div_f
  data = CSV.read("$(root)/data/dataset.csv")

  # Initialize
  isocodes=zeros(Int64,size(data,1))
  indices=zeros(Int64,999)
  lhs=zeros(Float64,109,35,35)
  z_gm=zeros(Float64,109,35,35)
  z_f=zeros(Float64,109,35,35)
  δ=zeros(Float64,109,35,35)
  δt=zeros(Float64,109,35,35)

  downstream = data[:,3]
  upstream = data[:,2]

  oldisocode = 0
  index = 0

  for i=1:size(data,1)
    # a change in the country code starts a new country block
    if oldisocode!=data[i,1]
      index+=1
    end
    # data format:
    # countrycode upstream downstream fraction delta delta_withtime n_div_gm n_div_f
    # ISO code as a function of the index
    isocodes[index]=data[i,1]
    # indices as a function of the ISO code
    indices[round(Int64,data[i,1])]=index
    # we set all matrix variables so that the downstream sector
    # is the row index (first), and the upstream sector is the
    # column index (second), to be more in line with the paper
    # (notation X_ni/X_n)
    lhs[index,downstream[i],upstream[i]]=data[i,4]
    # truncate δ at 1
    δ[index,downstream[i],upstream[i]] = data[i,5] < 1.0 ? data[i,5] : 1.0

    δt[index,downstream[i],upstream[i]]=data[i,6]
    z_gm[index,downstream[i],upstream[i]]=data[i,7]
    z_f[index,downstream[i],upstream[i]]=data[i,8]

    oldisocode=data[i,1];
  end

  # infuse some random noise into z_f: every sector pair (n,i) is rescaled, for
  # all countries at once, by a factor drawn uniformly from [0.8, 1.2).
  # The seed is fixed so that the perturbation is reproducible; the loop order
  # (i outer, n inner) determines which draw goes to which pair.
  rng = MersenneTwister(12345)
  for i = 1:35
    for n = 1:35
      r = 0.8 + 0.4*rand(rng, Float64)
      z_f[:,n,i] .*= r
    end
  end
  # # add a small number to remove the numerical instability at z=0
  # z_gm .+= 0.001
  # z_f .+= 0.001
  
  # ----------------------------------------------
  # READ COUNTRY NAMES ETC
  # ----------------------------------------------
  file = matread("$(root)/data/countrynames.mat")
  countrynames = file["countrynames"]
  file = matread("$(root)/data/countrynames_full.mat")
  countrynames_full = file["countrynames_full"]
  file = matread("$(root)/data/sectornames.mat")
  sectornames = file["sectornames"]

  # ----------------------------------------------
  # READ CONSUMPTION SHARE DATA
  # ----------------------------------------------
  # columns: ISO code, sector, share (no header line)
  cshares = CSV.read("$(root)/data/consumptionshares.csv", header=false)
  iso=cshares[:,1]
  sector=cshares[:,2]
  consumptionShares = zeros(Float64,109,35)
  for i=1:size(cshares, 1)
    # map the ISO code to the country index built above
    consumptionShares[indices[iso[i]],sector[i]]=cshares[i,3];
  end

end

# function eval_ces_problem!(fixedParameters::Dict{String,Array}, startingParameters::Dict{String,Array}, outputParameters::Dict{Any,Any}, settings::Dict{String,Any})

#   outstring = string("out/", string(now()))
#   outfile_csv = string(outstring, ".csv")
#   outfile_jld = string(outstring, ".jld")

#   # omegameasure: 1=Z_GM, 2=Z_F

#   CONST_DEFAULT_FTOL = 1e-8
#   CONST_DEFAULT_XTOL = 1e-8

#   if get(settings,"ftol",[]) != []
#     ftol=get(settings,"ftol",[])
#   else
#     ftol=CONST_DEFAULT_FTOL
#   end
#   if get(settings,"xtol",[]) != []
#     xtol=get(settings,"xtol",[])
#   else
#     xtol=CONST_DEFAULT_XTOL
#   end
#   if get(settings,"omegameasure",[]) == 1
#     z=z_gm
#   elseif get(settings,"omegameasure",[]) == 2
#     z=z_f
#   else
#     # DEFAULT: z_gm
#     println("Defaulting to using z_gm")
#     z=z_gm
#   end

#   # number of iterations
#   if get(settings,"maxiter",[]) != []
#     maxiter = get(settings,"maxiter",[]);
#   else
#     println("Defaulting to maxiter=500.")
#     maxiter = 500;
#   end

#   # whether we should use KNITRO or IPOPT
#   if get(settings,"knitro",[]) != []
#     if get(settings,"knitro",[])==1.0
#       println("Using KNITRO.")
#       use_knitro = 1;
#     else
#       println("Using IPOPT.")
#       use_knitro = 0;
#     end
#   else
#     # default to IPOPT
#     println("Defaulting to IPOPT.")
#     use_knitro = 0;
#   end

#   if use_knitro==1
#     m = Model(solver=KnitroSolver(
#     KTR_PARAM_MAXIT=maxiter,
#     KTR_PARAM_FTOL=ftol,
#     KTR_PARAM_XTOL=xtol,
#     KTR_PARAM_OPTTOL=1e-8,
#     KTR_PARAM_ALGORITHM=1,
#     KTR_PARAM_HESSOPT=1, # how to get hessians: default is exact (1)
#     KTR_PARAM_BLASOPTION=0,
#     KTR_PARAM_BAR_MAXCROSSIT=2,
#     KTR_PARAM_LINSOLVER=6, # which linear solver. 6=Intel MKL Pardiso
#     KTR_PARAM_LINSOLVER_OOC=1, # if LINSOLVER=6, use out of core option? 1=maybe
#     #KTR_PARAM_PAR_NUMTHREADS=2, # number of parallel threads for BLAS
#     KTR_PARAM_BAR_MURULE=2,
#     KTR_PARAM_PIVOT=1e-08,
#     KTR_PARAM_HONORBNDS=0,
#     KTR_PARAM_OUTLEV=3,
#     KTR_PARAM_OUTMODE=2,
#     KTR_PARAM_OUTAPPEND=1))
#   else
#     error("ipopt not implemented anymore")
#     # m = Model(solver=IpoptSolver(max_iter = maxiter, mu_strategy="adaptive") )
#   end

#   objects = Dict()
#   ccpar = 0 # number of cross-country parameters to estimate -- count variable

#   if get(fixedParameters,"ρ",[]) != []
#     fixedρ=get(fixedParameters,"ρ",[])
#     ρ = fixedρ[1]
#   elseif get(startingParameters,"ρ",[]) != []
#     @variable(m, ρ)
#     startρ=get(startingParameters,"ρ",[])
#     setvalue(ρ, startρ[1])
#     ccpar += 1
#     # impose a constraint that ρ>0
#     @constraint(m, ρ>=0)
#   else
#     # parameter not specified!
#     throw(UndefVarError(:ρ))
#   end
#   objects["ρ"]=ρ

#   if get(fixedParameters,"θ",[]) != []
#     fixedθ=get(fixedParameters,"θ",[])
#     θ = fixedθ[1]
#   elseif get(startingParameters,"θ",[]) != []
#     @variable(m, θ)
#     startθ=get(startingParameters,"θ",[])
#     setvalue(θ, startθ[1])
#     ccpar += 1
#   else
#     # parameter not specified!
#     throw(UndefVarError(:θ))
#   end
#   objects["θ"]=θ

#   if get(fixedParameters,"β",[]) != []
#     β = get(fixedParameters,"β",[])
#   elseif get(startingParameters,"β",[]) != []
#     @variable(m, β[1:2])
#     setvalue(β, get(startingParameters,"β",[]))
#     ccpar += 2
#   else
#     # parameter not specified!
#     throw(UndefVarError(:β))
#   end
#   objects["β"]=β

#   if get(fixedParameters,"η",[]) != []
#     η = get(fixedParameters,"η",[])
#   elseif get(startingParameters,"η",[]) != []
#     @variable(m, η[1:2])
#     setvalue(η, get(startingParameters,"η",[]))
#     ccpar += 2
#   else
#     # parameter not specified!
#     throw(UndefVarError(:η))
#   end
#   objects["η"]=η

#   if get(fixedParameters,"logT",[]) != []
#     logT = get(fixedParameters,"logT",[])
#   elseif get(startingParameters,"logT",[]) != []
#     @variable(m, logT[1:109,1:35])
#     setvalue(logT, get(startingParameters,String("logT"),[]))
#   else
#     # parameter not specified!
#     throw(UndefVarError(:logT))
#   end
#   objects["logT"]=logT

#   if get(fixedParameters,"logS",[]) != []
#     logS = get(fixedParameters,"logS",[])
#   elseif get(startingParameters,"logS",[]) != []
#     @variable(m, logS[1:109,1:35])
#     setvalue(logS, get(startingParameters,String("logS"),[]))
#   else
#     # parameter not specified!
#     throw(UndefVarError(:logS))
#   end
#   objects["logS"]=logS

#   if get(fixedParameters,"γ",[]) != []
#     γ = get(fixedParameters,"γ",[])
#   elseif get(startingParameters,"γ",[]) != []
#     @variable(m, γ[1:35,1:35] >= 0.0)
#     setvalue(γ, get(startingParameters,"γ",[]))
#     for n=1:35
#       @constraint(m, sum(γ[n,:])==1)
#       @constraint(m, γ[n,:].>=(0))
#     end
#   else
#     # parameter not specified!
#     throw(UndefVarError(:γ))
#   end
#   objects["γ"]=γ

#   # constraint that we dont get NaNs
#   if get(settings,"betaconstraint",[]) != []
#     if get(settings,"betaconstraint",[])==1.0
#       println("Registering beta constraint")
#       maxDelta=maximum(δ)
#       @constraint(m,β[1]*maxDelta+β[2]*maxDelta*maxDelta<=1.0)
#       @constraint(m,β[1]>=0.0)
#     else
#       # no constraint
#     end
#   else
#     # default to no constraint
#   end

#   # ORIGINAL VERSION
#   @NLexpression(m, expXi[c=1:109,n=1:35,i=1:35], exp(-logS[c,n]+logT[c,i]-θ*log(ifelse(1/(1-β[1]*δ[c,n,i]-β[2]*δ[c,n,i]*δ[c,n,i])<1+η[1]*z[c,n,i]+η[2]*z[c,n,i]*z[c,n,i],1/(1-β[1]*δ[c,n,i]-β[2]*δ[c,n,i]*δ[c,n,i]),1+η[1]*z[c,n,i]+η[2]*z[c,n,i]*z[c,n,i]))))
#   @NLexpression(m, pn[c=1:109,n=1:35], sum(γ[n,j]*(1+expXi[c,n,j])^((ρ-1)/θ) for j=1:35))
#   @NLexpression(m, g[c=1:109,n=1:35,i=1:35], (γ[n,i]*(1+expXi[c,n,i])^((ρ-1)/θ) / pn[c,n] )/(1+1/expXi[c,n,i]))

#   # FASTER VERSION
#   # this crashes
#   # @variable(m, Xi[1:109,1:35,1:35])
#   # @NLconstraint(m, xidef[c=1:109,n=1:35,i=1:35], Xi[c,n,i] == -logS[c,n]+logT[c,i]-θ*log(ifelse(1/(1-β[1]*δ[c,n,i]-β[2]*δ[c,n,i]*δ[c,n,i])<1+η[1]*z[c,n,i]+η[2]*z[c,n,i]*z[c,n,i],1/(1-β[1]*δ[c,n,i]-β[2]*δ[c,n,i]*δ[c,n,i]),1+η[1]*z[c,n,i]+η[2]*z[c,n,i]*z[c,n,i])) )
#   # @NLexpression(m, pn[c=1:109,n=1:35], sum(γ[n,j]*(1+exp(Xi[c,n,j]))^((ρ-1)/θ) for j=1:35))
#   # @NLexpression(m, g[c=1:109,n=1:35,i=1:35], (γ[n,i]*(1+exp(Xi[c,n,i]))^((ρ-1)/θ) / pn[c,n] )/(1+exp(-Xi[c,n,i])))

#   # CD:
#   #@NLexpression(m, g[c=1:109,n=1:35,i=1:35], (γ[n,i] )*1/(1+exp(-Xi[c,n,i])))

#   @NLobjective(m, Max, sum(lhs[c,n,i]*log(g[c,n,i])-g[c,n,i] for n=1:35 for i=1:35 for c=1:109))

#   @time solve(m)

#   fval = getobjectivevalue(m)

#   for par in ["β";"η";"logT";"logS";"γ"]
#     if typeof(objects[par])==JuMP.Variable || typeof(objects[par])==Array{JuMP.Variable} || typeof(objects[par])==Array{JuMP.Variable,1} || typeof(objects[par])==Array{JuMP.Variable,2}
#       outputParameters[par]=value.(objects[par])
#     else
#       outputParameters[par]=objects[par]
#     end
#   end
#   outputParameters["fval"]=fval[1]

#   # get objective function, gradient, and hessians
#   if typeof(θ)==Float64
#     ESTIMATE_THETA = 0
#   else
#     ESTIMATE_THETA = 1
#   end

#   # first put the solution vector into the 'values' object
#   values = zeros(Float64,ccpar+109*35*2 + 35*35)
#   if ESTIMATE_THETA == 1
#     values[linearindex(θ)] = outputParameters["θ"]
#   else
#     values[linearindex(β[1])] = outputParameters["β"][1]
#     values[linearindex(β[2])] = outputParameters["β"][2]
#   end
#   if typeof(ρ)!=Float64
#     values[linearindex(ρ)] = outputParameters["ρ"]
#   end
#   values[linearindex(η[1])] = outputParameters["η"][1]
#   values[linearindex(η[2])] = outputParameters["η"][2]
#   for c=1:109
#     for i=1:35
#       values[linearindex(logT[c,i])] = outputParameters["logT"][c,i]
#     end
#   end
#   for c=1:109
#     for n=1:35
#       values[linearindex(logS[c,n])] = outputParameters["logS"][c,n]
#     end
#   end
#   for n=1:35
#     for i=1:35
#       values[linearindex(γ[n,i])] = outputParameters["γ"][n,i]
#     end
#   end


#   # # create evaluator object and get objective function value
#   # evaluator = JuMP.NLPEvaluator(m)
#   # MathProgBase.initialize(evaluator, [:Grad,:HessVec,:Hess])
#   # objval = MathProgBase.eval_f(evaluator, values) # == sin(2.0) + sin(3.0)

#   # # get gradient
#   # ∇f = zeros(Float64,ccpar+109*35*2 + 35*35)
#   # MathProgBase.eval_grad_f(evaluator, ∇f, values)

#   # # get hessian
#   # structure = MathProgBase.hesslag_structure(evaluator)
#   # ∇2fvec = zeros(Float64,size(structure[1],1))
#   # MathProgBase.eval_hesslag(evaluator, ∇2fvec, values, 1.0, zeros(Float64,ccpar+109*35*2 + 35*35))

#   # hess = sparse([structure[1];structure[2]],[structure[2];structure[1]],[∇2fvec;∇2fvec] )

#   # VCov = Array{Float64,2};

#   # if ESTIMATE_THETA == 1
#   #   if maximum([linearindex(θ);linearindex(η[1]);linearindex(η[2])])==ccpar
#   #     # we're fine...
#   #     try
#   #       temp =  hess[1:ccpar,1:ccpar]\∇f[1:ccpar];
#   #       VCov = temp*temp';
#   #     catch
#   #       @show hess[1:ccpar,1:ccpar]
#   #       @show ∇f[1:ccpar]

#   #       println("Cheap VCov matrix didn't work. Trying full...")
#   #       try
#   #         temp = hess\∇f
#   #         VCov = temp*temp'
#   #       catch
#   #         println("Full VCov failed.")
#   #       end
#   #       #error("Problem calculating the VCov matrix")
#   #     end
#   #   else
#   #     println("Cheap VCov matrix didn't work. Trying full...")
#   #     try
#   #       temp = hess\∇f
#   #       VCov = temp*temp'
#   #     catch
#   #       println("Full VCov failed.")
#   #     end
#   #   end
#   # else
#   #   if maximum([linearindex(β[1]);linearindex(β[2]);linearindex(η[1]);linearindex(η[2])])==ccpar
#   #     # we can do this
#   #     # get just the covar matrix for the cc parameters
#   #     try
#   #       temp =  hess[1:ccpar,1:ccpar]\∇f[1:ccpar];
#   #       VCov = temp*temp';
#   #     catch
#   #       @show hess[1:ccpar,1:ccpar]
#   #       show(hess[1:ccpar,1:ccpar])
#   #       @show ∇f[1:ccpar]
#   #       show(∇f[1:ccpar])
#   #       println("Cheap VCov matrix didn't work. Trying full...")
#   #       try
#   #         temp = hess\∇f
#   #         VCov = temp*temp'
#   #       catch
#   #         println("Full VCov failed.")
#   #       end
#   #     end
#   #   else
#   #     println("Cheap VCov matrix didn't work. Trying full...")
#   #     try
#   #       temp = hess\∇f
#   #       VCov = temp*temp'
#   #     catch
#   #       println("Full VCov failed.")
#   #     end
#   #   end
#   # end

#   # try
#   #   outputParameters["Varθ"] = VCov[linearindex(θ),linearindex(θ)]
#   # end
#   # try
#   #   outputParameters["Varρ"] = VCov[linearindex(ρ),linearindex(ρ)]
#   # end
#   # try
#   #   outputParameters["Var_β1"] = VCov[linearindex(β[1]),linearindex(β[1])]
#   #   outputParameters["Var_β2"] = VCov[linearindex(β[2]),linearindex(β[2])]
#   # end
#   # try
#   #   outputParameters["Var_η1"] = VCov[linearindex(η[1]),linearindex(η[1])]
#   #   outputParameters["Var_η2"] = VCov[linearindex(η[2]),linearindex(η[2])]
#   # end

#   # get fitted values
#   fitted = zeros(Float64,109,35,35)
#   XiArray = zeros(Float64,109,35,35)
#   for c=1:109
#     for n=1:35
#       for i=1:35
#         # version with a fixed θ
#         XiArray[c,n,i] = -outputParameters["logS"][c,n]+outputParameters["logT"][c,i]-θ*log(ifelse(1/(1-outputParameters["β"][1]*δ[c,n,i]-outputParameters["β"][2]*δ[c,n,i]*δ[c,n,i])<1+outputParameters["η"][1]*z[c,n,i]+outputParameters["η"][2]*z[c,n,i]*z[c,n,i],1/(1-outputParameters["β"][1]*δ[c,n,i]-outputParameters["β"][2]*δ[c,n,i]*δ[c,n,i]),1+outputParameters["η"][1]*z[c,n,i]+outputParameters["η"][2]*z[c,n,i]*z[c,n,i]));
#       end
#     end
#   end
#       for c=1:109
#     for n=1:35
#       for i=1:35
#         fitted[c,n,i]=(outputParameters["γ"][n,i]*(1+exp(XiArray[c,n,i]))^((ρ-1)/θ) / sum(outputParameters["γ"][n,j]*(1+exp(XiArray[c,n,j]))^((ρ-1)/θ) for j=1:35) )*1/(1+exp(-XiArray[c,n,i]))
#       end
#     end
#   end
#   outputParameters["pseudor2"] = cor(vec(fitted),vec(lhs))^2

#   # save csv output file
#   # format is countrycode downstream upstream lhs fitted
#   outarray = zeros(Float64,109*35*35,5)
#   idx=1
#   for c=1:109
#     for n=1:35
#       for i=1:35
#         outarray[idx,:]=[isocodes[c] n i lhs[c,n,i] fitted[c,n,i] ]
#         idx=idx+1
#       end
#     end
#   end
#   try
#     writedlm(outfile_csv, outarray)
#   catch
#     println("Could not save iteration output. Continuing...")
#   end

#   if typeof(θ)==Float64
#     θout = θ
#     βout = value.(β)
#   else
#     θout = value.(θ)
#     βout = β
#   end
#   if typeof(ρ)==Float64
#     ρout = ρ
#   else
#     ρout = value.(ρ)
#   end
#   # not estimating theta
#   info("Exiting eval_problem with ρ=",ρout,", θ=",θout,", β=",βout,", η=", value.(η), ", fval=", fval)
#   println("Exiting eval_problem with ρ=",ρout,", θ=",θout,", β=",βout,", η=", value.(η), ", fval=", fval)

#   # try
#   #   save(outfile_jld, "out", outputParameters)
#   # catch
#   #   println("Could not save iteration output. Continuing...")
#   # end

#   return fval

# end

# eval_problem!(fixedParameters, startingParameters, outputParameters, settings)
#
# Build the JuMP model, solve it with KNITRO, and write the results into
# `outputParameters`. Returns the objective value at the solution.
#
# Each of "θ", "β", "η", "logT", "logS", "γ" must appear either in
# `fixedParameters` (held constant) or in `startingParameters` (estimated, with
# the given starting value); otherwise an UndefVarError is thrown.
# NOTE(review): the gradient/variance code below sizes its vectors assuming that
# η, logT, logS, γ and exactly one of θ / β are estimated -- confirm callers comply.
#
# Keys read from `settings`:
#   "ftol", "xtol"    solver tolerances (default 1e-8 each)
#   "omegameasure"    1 -> z_gm, 2 -> z_f, anything else -> z_gm
#   "savefitted"      if 1 and "outfile" is given, fitted values are written there
#   "outfile"         file name for the fitted values
#   "betaconstraint"  if 1.0, constrain β so the δ-polynomial stays below 1 at max(δ)
#
# Keys written to `outputParameters`:
#   "θ", "β", "η", "logT", "logS", "γ", "fval", "pseudor2",
#   "Var_η<i>", and either "Varθ" or "Var_β1"/"Var_β2".
function eval_problem!(fixedParameters::Dict{String,Array}, startingParameters::Dict{String,Array}, outputParameters::Dict{Any,Any}, settings::Dict{String,Any})

  # omegameasure: 1=Z_GM, 2=Z_F

  CONST_DEFAULT_FTOL = 1e-8
  CONST_DEFAULT_XTOL = 1e-8

  if get(settings,"ftol",[]) != []
    ftol=get(settings,"ftol",[])
  else
    ftol=CONST_DEFAULT_FTOL
  end
  if get(settings,"xtol",[]) != []
    xtol=get(settings,"xtol",[])
  else
    xtol=CONST_DEFAULT_XTOL
  end
  if get(settings,"omegameasure",[]) == 1
    z=z_gm
  elseif get(settings,"omegameasure",[]) == 2
    z=z_f
  else
    # DEFAULT: z_gm
    println("Defaulting to using z_gm")
    z=z_gm
  end
  # fitted values are only saved when both the flag and a file name are given
  if (get(settings,"savefitted",[]) == 1) && (get(settings,"outfile",[]) != [])
    savefitted = 1
    outfile = get(settings,"outfile",[])
  else
    # default
    savefitted = 0
  end

  m = JuMP.Model(with_optimizer(KNITRO.Optimizer,
              maxit=500,
              opttol=1e-8,
              ftol=ftol,
              xtol=xtol,
              hessopt=1, # how to get hessians: default is exact (1)
              blasoption=0,
              bar_maxcrossit=2,
              linsolver=6,
              linsolver_ooc=1, # if LINSOLVER=6, use out of core option? 1=maybe
              bar_murule=2,
              linsolver_pivottol=1e-08,
              honorbnds=0,
              outlev=3,
              outmode=0,    # 0 = screen only, 2 = screen and knitro.log
              outappend=1,
              newpoint=0,   # 0 = off, 2= save info about new point to knitro_newpoint.log
              ms_enable=0))

  # parameter name => fixed value or JuMP variable(s)
  objects = Dict()

  # NOTE: the order in which the variables are declared below (θ or β, then η,
  # then logT, logS, γ) fixes their model indices; the variance code further
  # down relies on the cross-country parameters coming first.

  if get(fixedParameters,"θ",[]) != []
    fixedθ=get(fixedParameters,"θ",[])
    θ = fixedθ[1]
  elseif get(startingParameters,"θ",[]) != []
    startθ=get(startingParameters,"θ",[])
    @variable(m, θ, start = startθ[1])
  else
    # parameter not specified!
    throw(UndefVarError(:θ))
  end
  objects["θ"]=θ

  if get(fixedParameters,"β",[]) != []
    β = get(fixedParameters,"β",[])
  elseif get(startingParameters,"β",[]) != []
    startβ = get(startingParameters,"β",[])
    @variable(m, β[i=1:2], start = startβ[i])
  else
    # parameter not specified!
    throw(UndefVarError(:β))
  end
  objects["β"]=β

  if get(fixedParameters,"η",[]) != []
    η = get(fixedParameters,"η",[])
  elseif get(startingParameters,"η",[]) != []
    startη = get(startingParameters,"η",[])
    @variable(m, η[i=1:CONST_DIM_ETA], start = startη[i])
  else
    # parameter not specified!
    throw(UndefVarError(:η))
  end
  objects["η"]=η

  if get(fixedParameters,"logT",[]) != []
    logT = get(fixedParameters,"logT",[])
  elseif get(startingParameters,"logT",[]) != []
    s = get(startingParameters,String("logT"),[])
    @variable(m, logT[c=1:109,i=1:35], start = s[c,i])
  else
    # parameter not specified!
    throw(UndefVarError(:logT))
  end
  objects["logT"]=logT

  if get(fixedParameters,"logS",[]) != []
    logS = get(fixedParameters,"logS",[])
  elseif get(startingParameters,"logS",[]) != []
    s = get(startingParameters,String("logS"),[])
    @variable(m, logS[c=1:109,n=1:35], start = s[c,n] )
  else
    # parameter not specified!
    throw(UndefVarError(:logS))
  end
  objects["logS"]=logS

  if get(fixedParameters,"γ",[]) != []
    γ = get(fixedParameters,"γ",[])
  elseif get(startingParameters,"γ",[]) != []
    s = get(startingParameters,"γ",[])
    @variable(m, γ[n=1:35,i=1:35] >= 0.0, start = s[n,i])
    # every row of γ is non-negative and sums to one
    for n=1:35
      @constraint(m, sum(γ[n,:])==1)
      @constraint(m, γ[n,:].>=(0))
    end
  else
    # parameter not specified!
    throw(UndefVarError(:γ))
  end
  objects["γ"]=γ

  # constraint that we dont get NaNs
  if get(settings,"betaconstraint",[]) != []
    if get(settings,"betaconstraint",[])==1.0
      maxDelta=maximum(δ)
      @constraint(m,β[1]*maxDelta+β[2]*maxDelta*maxDelta<=1)
      @constraint(m,β[1]>=0)
    else
      # no constraint
    end
  else
    # default to no constraint
  end

  # CONST_DIM_ETA
  #println("Starting estimation with β[1]=$(startβ[1]), β[2]=$(startβ[2]), η[1]=$(startη[1]), η[2]=$(startη[2])") #, η[3]=$(startη[3]), η[4]=$(startη[4])
  # CONST_DIM_ETA
  # g uses the smaller of 1/(1-β1*δ-β2*δ^2) and 1+η1*z+η2*z^2 inside the log
  @NLexpression(m, g[c=1:109,n=1:35,i=1:35], γ[n,i]*1/(1+exp(logS[c,n]-logT[c,i]+θ*log(ifelse(1/(1-β[1]*δ[c,n,i]-β[2]*δ[c,n,i]*δ[c,n,i])<1+η[1]*z[c,n,i]+η[2]*z[c,n,i]*z[c,n,i],1/(1-β[1]*δ[c,n,i]-β[2]*δ[c,n,i]*δ[c,n,i]),1+η[1]*z[c,n,i]+η[2]*z[c,n,i]*z[c,n,i])))))
  @NLobjective(m, Max, sum(lhs[c,n,i]*log(g[c,n,i])-g[c,n,i] for n=1:35 for i=1:35 for c=1:109))

  optimize!(m)

  @show termination_status(m)

  fval = objective_value(m)

  # fixed parameters are copied as they are, estimated ones are read from the solution
  for par in ["θ";"β";"η";"logT";"logS";"γ"]
    if isa(objects[par], Float64) || isa(objects[par], Array{Float64,1}) || isa(objects[par], Array{Float64,2})
      outputParameters[par]=objects[par]
    else
      outputParameters[par]=value.(objects[par])
    end
  end
  outputParameters["fval"]=fval[1]

  # get objective function, gradient, and hessians
  if isa(θ, Float64)
    ESTIMATE_THETA = 0
  else
    ESTIMATE_THETA = 1
  end

  # number of cross-country parameters
  if ESTIMATE_THETA == 1
    ccpar = 1 + CONST_DIM_ETA  # if estimating theta and eta
  else
    ccpar = 2 + CONST_DIM_ETA # if estimating beta and eta
  end

  # first put the solution vector into the 'values' object
  values = zeros(Float64,ccpar+109*35*2 + 35*35)
  if ESTIMATE_THETA == 1
    values[JuMP.index(θ).value] = outputParameters["θ"]
  else
    values[JuMP.index(β[1]).value] = outputParameters["β"][1]
    values[JuMP.index(β[2]).value] = outputParameters["β"][2]
  end
  for i = 1:CONST_DIM_ETA
    values[JuMP.index(η[i]).value] = outputParameters["η"][i]
  end
  for c=1:109
    for i=1:35
      values[JuMP.index(logT[c,i]).value] = outputParameters["logT"][c,i]
    end
  end
  for c=1:109
    for n=1:35
      values[JuMP.index(logS[c,n]).value] = outputParameters["logS"][c,n]
    end
  end
  for n=1:35
    for i=1:35
      values[JuMP.index(γ[n,i]).value] = outputParameters["γ"][n,i]
    end
  end

  # create evaluator object and get objective function value
  evaluator = JuMP.NLPEvaluator(m)
  MOI.initialize(evaluator, [:Grad,:HessVec,:Hess])
  objval = MOI.eval_objective(evaluator, values) # == sin(2.0) + sin(3.0)

  # get gradient
  ∇f = zeros(Float64,ccpar+109*35*2 + 35*35)
  MOI.eval_objective_gradient(evaluator, ∇f, values)

  # get hessian
  structure = MOI.hessian_lagrangian_structure(evaluator);
  ∇2fvec = zeros(Float64, size(structure,1));
  MOI.eval_hessian_lagrangian(evaluator, ∇2fvec, values, 1.0, zeros(Float64,ccpar+109*35*2 + 35*35));

  VCov = zeros(Float64, size(∇f,1),size(∇f,1));

  # Assemble the symmetric Hessian from its (row, column, value) triplets.
  # The diagonal is subtracted once because raw + raw' counts it twice.
  function dense_hessian(hessian_sparsity, V)
    rows = [i for (i,j) in hessian_sparsity]
    cols = [j for (i,j) in hessian_sparsity]
    raw = sparse(rows, cols, V)
    # spdiagm builds the diagonal correction directly in sparse form
    # instead of going through a dense n-by-n matrix
    return Symmetric(raw + raw' - spdiagm(0 => Vector(diag(raw))))
  end

  hess = -dense_hessian(structure,∇2fvec)

  try
    temp = hess\∇f
    VCov = temp*temp'

    # NOTE(review): this unconditional error looks like a debugging leftover.
    # It makes the "cheap" branch below run on every call, so the reported
    # variances normally come from the ccpar-by-ccpar block. It is kept because
    # removing it would change the reported numbers -- confirm the intent.
    error("Bla")

  catch

    println("Error computing covariance matrix. Trying cheap variation...")

    # model indices of the cross-country parameters (θ or β, followed by η)
    if ESTIMATE_THETA == 1
      ccindices = [JuMP.index(θ).value]
    else
      ccindices = [JuMP.index(β[1]).value, JuMP.index(β[2]).value]
    end
    for i = 1:CONST_DIM_ETA
      push!(ccindices, JuMP.index(η[i]).value)
    end

    # the cheap variation only works if these parameters occupy the first
    # ccpar positions of the variable vector
    if maximum(ccindices)==ccpar
      # get just the covar matrix for the cc parameters
      try
        temp =  hess[1:ccpar,1:ccpar]\∇f[1:ccpar];
        VCov = temp*temp';
        # report success only if the reduced solve actually went through
        println("Success.")
      catch
        @show hess[1:ccpar,1:ccpar]
        show(hess[1:ccpar,1:ccpar])
        @show ∇f[1:ccpar]
        show(∇f[1:ccpar])
      end
    else
      # we cant do this
      println("No way. Check the ordering of the variables. This shouldnt happen...")
    end

  end

  if ESTIMATE_THETA == 1
    outputParameters["Varθ"] = VCov[JuMP.index(θ).value,JuMP.index(θ).value]
  else
    outputParameters["Var_β1"] = VCov[JuMP.index(β[1]).value,JuMP.index(β[1]).value]
    outputParameters["Var_β2"] = VCov[JuMP.index(β[2]).value,JuMP.index(β[2]).value]
  end
  for i=1:CONST_DIM_ETA
    outputParameters["Var_η$(i)"] = VCov[JuMP.index(η[i]).value,JuMP.index(η[i]).value]
  end

  # get fitted values
  fitted = zeros(Float64,109,35,35)
  dI = zeros(Float64,109,35,35)
  for c=1:109
    for n=1:35
      for i=1:35
        # CONST_DIM_ETA
        dI[c,n,i] = 1.0+outputParameters["η"][1]*z[c,n,i]+outputParameters["η"][2]*z[c,n,i]*z[c,n,i]
        fitted[c,n,i]=outputParameters["γ"][n,i]*1/(1+exp(outputParameters["logS"][c,n]-outputParameters["logT"][c,i]+outputParameters["θ"]*log(ifelse(1/(1-outputParameters["β"][1]*δ[c,n,i]-outputParameters["β"][2]*δ[c,n,i]*δ[c,n,i])< dI[c,n,i] ,1/(1-outputParameters["β"][1]*δ[c,n,i]-outputParameters["β"][2]*δ[c,n,i]*δ[c,n,i]),dI[c,n,i]))))
      end
    end
  end
  # squared correlation between fitted and observed values
  outputParameters["pseudor2"] = cor(vec(fitted),vec(lhs))^2
  if savefitted==1
    # format is countrycode downstream upstream lhs fitted
    outarray = zeros(Float64,109*35*35,5)
    idx=1
    for c=1:109
      for n=1:35
        for i=1:35
          outarray[idx,:]=[isocodes[c] n i lhs[c,n,i] fitted[c,n,i] ]
          idx=idx+1
        end
      end
    end
    writedlm(outfile, outarray)
  end

  if isa(θ, Float64)
    # not estimating theta
    @info "Exiting eval_problem" θ outputParameters["β"] outputParameters["η"] fval
    # info("Exiting eval_problem with θ=",θ,", β=",value.(β),", η=", value.(η), ", fval=", fval)
    # println("Exiting eval_problem with θ=",θ,", β=",value.(β),", η=", value.(η), ", fval=", fval)
  else
    # probably estimating theta, but not beta
    @info "Exiting eval_problem" outputParameters["θ"] β outputParameters["η"] fval
    # info("Exiting eval_problem with θ=",value.(θ),", β=",β,", η=", value.(η), ", fval=", fval)
    # println("Exiting eval_problem with θ=",value.(θ),", β=",β,", η=", value.(η), ", fval=", fval)
  end


  return fval

end

# Convenience method: same as the seven-argument solveWithFixedTheta, with an
# empty output file name (so no fitted values are saved).
function solveWithFixedTheta(omegameasure::Float64, startβ::Vector{Float64}, startη::Vector{Float64}, startγ::Array{Float64,2}, startlogT::Array{Float64,2}, startlogS::Array{Float64,2})
  return solveWithFixedTheta(omegameasure,startβ,startη,startγ,startlogT,startlogS,"")
end
# Estimate β, η, γ, logT and logS with θ held fixed at 4.0, starting from the
# supplied values. If `saveoutfilename` is non-empty, eval_problem! is asked to
# write the fitted values to that file. Returns the output dictionary filled by
# eval_problem!.
function solveWithFixedTheta(omegameasure::Float64, startβ::Vector{Float64}, startη::Vector{Float64}, startγ::Array{Float64,2}, startlogT::Array{Float64,2}, startlogS::Array{Float64,2}, saveoutfilename::String)
  fixedParameters = Dict{String,Array}("θ"=>[4.0])

  startingParameters = Dict{String,Array}(
    "β"=>startβ,
    "η"=>startη,
    "γ"=>startγ,
    "logT"=>startlogT,
    "logS"=>startlogS)

  settings = Dict{String,Any}(
    "ftol"=>1e-9,
    "xtol"=>1e-9,
    "omegameasure"=>omegameasure,
    "betaconstraint"=>1.0)
  if !isempty(saveoutfilename)
    settings["savefitted"]=1
    settings["outfile"]=saveoutfilename
  end

  results = Dict{Any,Any}()
  eval_problem!(fixedParameters, startingParameters, results, settings)
  return results
end

# function solveWithStartingβη_gm(startβ::Vector{Float64},startη::Vector{Float64})
#
#   fixedθ = 4.0
#
#   fixedParameters = Dict{String,Array}("θ"=>[fixedθ])
#   startingParameters = Dict{String,Array}("β"=>startβ, "η"=>startη, "γ"=> 1/35.*ones(Float64,35,35), "logT"=> zeros(Float64,109,35), "logS"=>zeros(Float64,109,35))
#   outputParameters=Dict()
#   settings = Dict{String,Float64}("ftol"=>1e-9,"xtol"=>1e-9,"omegameasure"=>1.0,"betaconstraint"=>1.0)
#   eval_problem!(fixedParameters, startingParameters, outputParameters, settings)
#
#   return outputParameters
# end
#
# function solveWithStartingβη_f(startβ::Vector{Float64},startη::Vector{Float64})
#
#   fixedθ = 4.0
#
#   fixedParameters = Dict{String,Array}("θ"=>[fixedθ])
#   startingParameters = Dict{String,Array}("β"=>startβ, "η"=>startη, "γ"=> 1/35.*ones(Float64,35,35), "logT"=> zeros(Float64,109,35), "logS"=>zeros(Float64,109,35))
#   outputParameters=Dict()
#   settings = Dict{String,Float64}("ftol"=>1e-9,"xtol"=>1e-9,"omegameasure"=>2.0,"betaconstraint"=>1.0)
#   eval_problem!(fixedParameters, startingParameters, outputParameters, settings)
#
#   return outputParameters
# end

# CES variant with ρ and θ both fixed at 4.0; β and η start from the given
# values, γ starts uniform (1/35) and logT, logS start at zero.
# NOTE(review): eval_ces_problem! appears only as commented-out code in the part
# of the file reviewed here -- confirm an active definition exists elsewhere.
function solveCESWithρθ(omegameasure::Float64,startβ::Vector{Float64}, startη::Vector{Float64})
  #fixedβ=[0.5 0.0]'

  fixedParameters = Dict{String,Array}("ρ"=>[4.0], "θ" => [4.0])

  startingParameters = Dict{String,Array}(
    "β"=>startβ,
    "η"=>startη,
    "γ"=>fill(1/35, 35, 35),
    "logT"=>zeros(109,35),
    "logS"=>zeros(109,35))

  settings = Dict{String,Any}(
    "ftol"=>1e-8,
    "xtol"=>1e-8,
    "omegameasure"=>omegameasure,
    "betaconstraint"=>1.0,
    "knitro"=>1.0,
    "maxiter"=>1000)

  results = Dict{Any,Any}()
  eval_ces_problem!(fixedParameters, startingParameters, results, settings)
  return results
end

# CES variant with ρ and θ both fixed at 4.0 and explicit starting values for
# all remaining parameters. A non-empty `saveoutfilename` is passed on through
# the "savefitted"/"outfile" settings.
# NOTE(review): eval_ces_problem! appears only as commented-out code in the part
# of the file reviewed here -- confirm an active definition exists elsewhere.
function solveCESWithρθ(omegameasure::Float64,startβ::Vector{Float64}, startη::Vector{Float64},startγ::Array{Float64,2}, startlogT::Array{Float64,2}, startlogS::Array{Float64,2}, saveoutfilename::String)
  #fixedβ=[0.5 0.0]'

  fixedParameters = Dict{String,Array}("ρ"=>[4.0], "θ" => [4.0])

  startingParameters = Dict{String,Array}(
    "β"=>startβ,
    "η"=>startη,
    "γ"=>startγ,
    "logT"=>startlogT,
    "logS"=>startlogS)

  settings = Dict{String,Any}(
    "ftol"=>1e-8,
    "xtol"=>1e-8,
    "omegameasure"=>omegameasure,
    "betaconstraint"=>1.0,
    "knitro"=>1.0,
    "maxiter"=>1000)
  if !isempty(saveoutfilename)
    settings["savefitted"]=1
    settings["outfile"]=saveoutfilename
  end

  results = Dict{Any,Any}()
  eval_ces_problem!(fixedParameters, startingParameters, results, settings)
  return results
end

# Estimate θ and η (plus γ, logT, logS) with β fixed at (0.5, 0.0), using the
# z_gm measure (omegameasure 1.0) and no β constraint. γ starts uniform (1/35),
# logT and logS start at zero. Returns the output dictionary filled by
# eval_problem!.
function solveWithθ_gm(startθ::Vector{Float64}, startη::Vector{Float64})

  fixedParameters = Dict{String,Array}("β"=>[0.5 0.0]')

  startingParameters = Dict{String,Array}(
    "θ"=>startθ,
    "η"=>startη,
    "γ"=>fill(1/35, 35, 35),
    "logT"=>zeros(109,35),
    "logS"=>zeros(109,35))

  settings = Dict{String,Any}(
    "ftol"=>1e-9,
    "xtol"=>1e-9,
    "omegameasure"=>1.0,
    "betaconstraint"=>0.0)

  results = Dict{Any,Any}()
  eval_problem!(fixedParameters, startingParameters, results, settings)
  return results
end
# Variant with starting values for all free parameters (θ, η, γ, logT, logS).
# β is held fixed at (0.5, 0.0) and "omegameasure" is 1.0. When saveoutfilename
# is non-empty, the "savefitted" and "outfile" settings are added.
# Returns the dictionary that eval_problem! fills with its output.
function solveWithθ_gm(startθ::Vector{Float64}, startη::Vector{Float64}, startγ::Array{Float64,2}, startlogT::Array{Float64,2}, startlogS::Array{Float64,2}, saveoutfilename::String)
  # transposed row literal: β enters as a 2x1 column
  held = Dict{String,Array}("β" => [0.5 0.0]')
  initial = Dict{String,Array}(
    "θ" => startθ,
    "η" => startη,
    "γ" => startγ,
    "logT" => startlogT,
    "logS" => startlogS,
  )
  options = Dict{String,Any}(
    "ftol" => 1e-9,
    "xtol" => 1e-9,
    "omegameasure" => 1.0,
    "betaconstraint" => 0.0,
  )
  if !isempty(saveoutfilename)
    options["savefitted"] = 1
    options["outfile"] = saveoutfilename
  end

  results = Dict()
  eval_problem!(held, initial, results, options)
  return results
end

# Estimate θ and η with β held fixed at (0.5, 0.0) and "omegameasure" set to
# 2.0, using default starting values for everything else: γ uniform at 1/35,
# logT and logS identically zero. No output file is configured.
function solveWithθ_f(startθ::Vector{Float64}, startη::Vector{Float64})
  defaultγ = 1/35 .* ones(Float64,35,35)
  defaultlogT = zeros(Float64,109,35)
  defaultlogS = zeros(Float64,109,35)

  # Same problem as the full-argument method; the empty file name keeps the
  # "savefitted"/"outfile" settings out of the settings dictionary.
  return solveWithθ_f(startθ, startη, defaultγ, defaultlogT, defaultlogS, "")
end
# Variant with starting values for all free parameters (θ, η, γ, logT, logS).
# β is held fixed at (0.5, 0.0) and "omegameasure" is 2.0. When saveoutfilename
# is non-empty, the "savefitted" and "outfile" settings are added.
# Returns the dictionary that eval_problem! fills with its output.
function solveWithθ_f(startθ::Vector{Float64}, startη::Vector{Float64}, startγ::Array{Float64,2}, startlogT::Array{Float64,2}, startlogS::Array{Float64,2}, saveoutfilename::String)
  # transposed row literal: β enters as a 2x1 column
  held = Dict{String,Array}("β" => [0.5 0.0]')
  initial = Dict{String,Array}(
    "θ" => startθ,
    "η" => startη,
    "γ" => startγ,
    "logT" => startlogT,
    "logS" => startlogS,
  )
  options = Dict{String,Any}(
    "ftol" => 1e-9,
    "xtol" => 1e-9,
    "omegameasure" => 2.0,
    "betaconstraint" => 0.0,
  )
  if !isempty(saveoutfilename)
    options["savefitted"] = 1
    options["outfile"] = saveoutfilename
  end

  results = Dict()
  eval_problem!(held, initial, results, options)
  return results
end

# Get the structural T and S from the 'fixed effects'.
# FETi and FESn should be in levels, not logs (i.e. positive only!).
#
# For every country c (one per row of T) this computes the sectoral price
# vector pc implied by γ, FESn[c,:], FETi[c,:] and dni[c,:,:], and then writes
#   S[c,n] = FESn[c,n] * μ^(-θ)
#   T[c,i] = FETi[c,i] * pc[i]^θ
# in place. The number of sectors is taken from γ and the number of countries
# from T, so the 109-country / 35-sector case is no longer baked in.
#
# Arguments:
#   θ     elasticity used in the price aggregation
#   β     not used by this function; kept so existing callers keep working
#   γ     sector-by-sector weights (rows: downstream n, columns: upstream i)
#   FETi  country-by-sector fixed effects, in levels
#   FESn  country-by-sector fixed effects, in levels
#   dni   country-by-sector-by-sector wedges
#   T, S  output arrays, overwritten row by row
# Returns nothing.
function getDeepParameters!(θ::Float64, β::Array{Float64,1}, γ::Array{Float64,2}, FETi::Array{Float64,2}, FESn::Array{Float64,2},dni::Array{Float64,3},T::Array{Float64,2}, S::Array{Float64,2})

  # we have:
  # 'logS'= FESn = log(S/μ)
  # 'logT'= FETi = log(T*p^(-θ))

  #σ=3.5
  α=1.0#gamma((1-σ)./θ + 1).^(1.0 ./(1-σ));
  μ=1.0#σ./(σ-1)

  nsectors = size(γ,1)

  for c in axes(T,1)
    # FESnc[n,i] = FESn[c,n] (constant along rows); FETic[n,i] = FETi[c,i]
    FESnc = repeat(FESn[c,:],1,nsectors)
    FETic = repeat(FETi[c,:]',nsectors,1)
    dnic = dni[c,:,:]

    # calculate price vector (one entry per row n of γ)
    pc = exp.(sum(γ .* log.((α ./ γ) .* (FESnc .* μ.^(-θ) + FETic .* μ.^(-θ) .* dnic .^ (-θ)) .^ (-1.0 ./ θ)),dims=2))

    # back out Ti and Sn
    S[c,:] = FESn[c,:] .* μ.^(-θ)
    T[c,:] = FETi[c,:] .* vec(pc).^θ
  end

  return nothing
end
# # CES version of the above
# function getDeepParameters_ces!(ρ::Float64, θ::Float64, γ::Array{Float64,2}, FETi::Array{Float64,2}, FESn::Array{Float64,2},dni::Array{Float64,3},T::Array{Float64,2}, S::Array{Float64,2})

#   σ=3.5
#   α=gamma((1-σ) ./ θ + 1) .^ (1.0 ./ (1-σ));
#   μ=σ./(σ-1)

#   SDIM=35

#   for c=1:109


#     FESnc=repeat(FESn[c,:],1,SDIM);
#     FETic=repeat(FETi[c,:]',SDIM,1);
#     dnic=dni[c,:,:]

#     # calculate price vector
#     pc = ( sum( γ.*α.^(1-ρ).*( FESnc.*μ.^(-θ) + FETic.*μ.^(-θ).*dnic.^(-θ) ).^(-(1-ρ)/θ) ,dims=2) ).^(1/(1-ρ))

#     #@show pc

#     # back out Ti and Sn
#     S[c,:] = FESn[c,:]'.*μ.^(-θ);
#     T[c,:] = FETi[c,:].*(pc.^θ);


#   end
# end

# Counterfactual welfare calculations for every country; version where we pass
# dI as opposed to η and z.
#
# Builds dni[c,n,i] as 1/(1-β[1]*δ-β[2]*δ^2) where that is below dI[c,n,i] and
# as dI[c,n,i] otherwise, converts the fixed effects FETi/FESn into T and S with
# getDeepParameters!, and then solves each country three times with solveCountry:
#   * with its own dni (baseline),
#   * with dni equal to one everywhere (reported as "To zero"),
#   * with the dni of country index 106 (reported as "To US").
# One summary line is printed per country, followed by the two means. When
# plotfilename is non-empty, a scatter plot is written to "<plotfilename>.svg"
# and converted to PDF by the external `cairosvg` command.
#
# Returns a tuple: the 109x2 matrix [percentageChange_US percentageChange], the
# baseline X_ni/X_n shares, and the shares under the country-106 dni.
function solveCountries(θ::Float64, β::Vector{Float64}, dI::Array{Float64,3}, γ::Array{Float64,2}, FETi::Array{Float64,2}, FESn::Array{Float64,2}, plotfilename::String, plottitle::String)
  global δ, countrynames, countrynames_full, consumptionShares

  ncountries = 109
  nsectors = 35

  percentageChange = zeros(Float64,ncountries)
  percentageChange_US = zeros(Float64,ncountries)
  ioshares_before = zeros(Float64,ncountries,nsectors,nsectors)
  ioshares_after_us = zeros(Float64,ncountries,nsectors,nsectors)

  # create dni
  dni = zeros(Float64,ncountries,nsectors,nsectors)
  for c in 1:ncountries, n in 1:nsectors, i in 1:nsectors
    dδ = 1/(1-β[1]*δ[c,n,i]-β[2]*δ[c,n,i]*δ[c,n,i])
    dni[c,n,i] = dδ < dI[c,n,i] ? dδ : dI[c,n,i]
  end

  # convert fixed effects into structural parameters
  T = zeros(Float64,ncountries,nsectors)
  S = zeros(Float64,ncountries,nsectors)
  getDeepParameters!(θ, β, γ, FETi, FESn, dni, T, S)

  for c in 1:ncountries
    logTc = log.(T[c,:])
    logSc = log.(S[c,:])
    sharesc = consumptionShares[c,:]

    # baseline: the country's own dni
    (before, XniXn, y_before, p_before) = solveCountry(θ, γ, logTc, logSc, dni[c,:,:], sharesc)
    ioshares_before[c,:,:] = XniXn

    # dni equal to one everywhere; only output per capita is needed
    after = first(solveCountry(θ, γ, logTc, logSc, ones(Float64,nsectors,nsectors), sharesc))
    percentageChange[c] = (after-before)/before

    # dni of country index 106
    (after_us, XniXn, y_after_us, p_after_us) = solveCountry(θ, γ, logTc, logSc, dni[106,:,:], sharesc)
    percentageChange_US[c] = (after_us-before)/before
    ioshares_after_us[c,:,:] = XniXn

    println("$(countrynames_full[c]),$(countrynames[c]), delta: $(δ[c,1,1]), To zero: $(percentageChange[c]), To US: $(percentageChange_US[c]), Frac dY: $(((y_after_us-y_before)/y_before)/(percentageChange_US[c])), DeltaP: $((p_after_us-p_before)/p_before))")
  end
  println("Mean: ", mean(percentageChange))
  println("Mean To US: ", mean(percentageChange_US))

  # do plot
  if !isempty(plotfilename)
    xticks = [0.25, 0.5, 0.75, 1]

    df = DataFrame(x = δ[:,1,1], y = 100.0 .* percentageChange_US, name = vec(countrynames))
    p=Gadfly.plot(df, x=:x, y=:y, Geom.point, label =:name, Geom.label, 
      Coord.Cartesian(xmin=0.08,xmax=1.08), Guide.xticks(ticks=xticks), 
      Guide.xlabel("Enforcement cost"), Guide.ylabel("Counterfactual welfare increase, in percent",orientation=:vertical), 
      Guide.title(plottitle))
    draw(SVG("$(plotfilename).svg",15cm, 10cm),p)
    run(`cairosvg $(plotfilename).svg -o $(plotfilename).pdf`)
  end

  return hcat(percentageChange_US, percentageChange), ioshares_before, ioshares_after_us

end

# Counterfactual welfare calculations for every country; version where dI is
# built from η and z:
#   dI[c,n,i] = 1 + η[1]*z[c,n,i] + η[2]*z[c,n,i]^2
# with z = z_f when omegameasure == 2.0 and z = z_gm otherwise.
#
# Everything after the construction of dI (dni, structural parameters, the
# per-country solves, printing, the optional plot and the return value) is
# identical to the method that takes dI directly, so this method delegates to
# it instead of carrying a second copy of that code.
#
# Returns whatever the dI method returns: the matrix
# [percentageChange_US percentageChange], ioshares_before and ioshares_after_us.
function solveCountries(θ::Float64, β::Vector{Float64}, η::Vector{Float64}, γ::Array{Float64,2}, omegameasure::Float64, FETi::Array{Float64,2}, FESn::Array{Float64,2}, plotfilename::String, plottitle::String)

  global z_gm, z_f

  # z_f for omegameasure 2.0, otherwise z_gm
  z = omegameasure==2.0 ? z_f : z_gm

  dI = zeros(Float64,109,35,35)
  for c=1:109
    for n=1:35
      for i=1:35
        # CONST_DIM_ETA
        dI[c,n,i] = 1+η[1]*z[c,n,i]+η[2]*z[c,n,i]*z[c,n,i]
      end
    end
  end

  return solveCountries(θ, β, dI, γ, FETi, FESn, plotfilename, plottitle)
end

# # CES Version
# function solveCountries_ces(ρ::Float64, θ::Float64, β::Vector{Float64}, η::Vector{Float64}, γ::Array{Float64,2}, omegameasure::Float64, FETi::Array{Float64,2}, FESn::Array{Float64,2}, plotfilename::String, plottitle::String)

#   global isocodes, indices, lhs, z_gm, z_f, δ, δt, countrynames, countrynames_full, consumptionShares

#   outputPerCapitaBefore=zeros(Float64,109)
#   outputPerCapitaAfter=zeros(Float64,109)
#   outputPerCapitaAfter_US=zeros(Float64,109)
#   percentageChange=zeros(Float64,109)
#   percentageChange_US=zeros(Float64,109)
#   ioshares_before=zeros(Float64,109,35,35)
#   ioshares_after_us=zeros(Float64,109,35,35)

#   #@show β
#   #@show η

#   # create dni
#   if omegameasure==2.0
#     # z_f
#     z=z_f
#   else
#     # otherwise z_gm
#     z=z_gm
#   end
#   dni=zeros(Float64,109,35,35)
#   for c=1:109
#     for n=1:35
#       for i=1:35
#         dni[c,n,i]=ifelse(1/(1-β[1]*δ[c,n,i]-β[2]*δ[c,n,i]*δ[c,n,i])<1+η[1]*z[c,n,i]+η[2]*z[c,n,i]*z[c,n,i],1/(1-β[1]*δ[c,n,i]-β[2]*δ[c,n,i]*δ[c,n,i]),1+η[1]*z[c,n,i]+η[2]*z[c,n,i]*z[c,n,i])
#       end
#     end
#   end


#   # convert fixed effects into structural parameters
#   S=zeros(Float64,109,35)
#   T=zeros(Float64,109,35)
#   getDeepParameters_ces!(ρ, θ, γ, FETi, FESn ,dni,T, S)

#   #ioshares_before = zeros(Float64,109,35,35)
#   #ioshares_after_us = zeros(Float64,109,35,35)

#   for c=1:109
#     dnic=dni[c,:,:];
#     #dnius = squeeze(dniusfull[c,:,:],1)
#     XniXn=zeros(Float64,35,35)
#     (outputPerCapitaBefore[c],XniXn) = solveCountry_ces(ρ, θ, β, η, γ, log.(T[c,:]), log.(S[c,:]), dnic, consumptionShares[c,:])
#     ioshares_before[c,:,:]=XniXn

#     (outputPerCapitaAfter[c],XniXn) = solveCountry_ces(ρ, θ, β, η, γ, log.(T[c,:]), log.(S[c,:]), ones(Float64,35,35), consumptionShares[c,:])
#     percentageChange[c]=(outputPerCapitaAfter[c]-outputPerCapitaBefore[c])/outputPerCapitaBefore[c]


#     (outputPerCapitaAfter_US[c],XniXn) = solveCountry_ces(ρ, θ, β, η, γ, log.(T[c,:]), log.(S[c,:]), dni[106,:,:], consumptionShares[c,:])
#     percentageChange_US[c]=(outputPerCapitaAfter_US[c]-outputPerCapitaBefore[c])/outputPerCapitaBefore[c]
#     ioshares_after_us[c,:,:]=XniXn

#     println("$(countrynames_full[c]),$(countrynames[c]), delta: $(δ[c,1,1]), To zero: $(percentageChange[c]), To US: $(percentageChange_US[c])")
#   end
#   println("Mean: ", mean(percentageChange))
#   println("Mean To US: ", mean(percentageChange_US))

#   # do plot

#   deltavec = δ[:,1,1]

#   Labels = countrynames
#   xticks = [0.25, 0.5, 0.75, 1]

#   # uncomment this when using Gadfly
#   # p=plot(x=deltavec[percentageChange_US.<0.4], y=100*percentageChange_US[percentageChange_US.<0.4],label=vec(countrynames)[percentageChange_US.<0.4], Geom.point, Geom.label, Coord.cartesian(xmin=0.08,xmax=1.08), Guide.xticks(ticks=xticks), Guide.xlabel("Enforcement cost"), Guide.ylabel("Counterfactual welfare increase, in percent",orientation=:vertical), Guide.title(plottitle))
#   # draw(PNG(string(plotfilename,".png"), 12cm, 6cm), p)
#   # draw(SVGJS(string(plotfilename,".svg"), 16cm, 12cm), p)

#   return [percentageChange_US percentageChange], ioshares_before, ioshares_after_us
# end


# # Version where we just supply a δ
# function solveCountry(θ::Float64, β::Vector{Float64}, η::Vector{Float64}, γ::Array{Float64,2}, logT::Array{Float64,1}, logS::Array{Float64,1}, δnew::Float64)
#   # CONST_DIM_ETA
#   dninew = ifelse(1+bestβ[1].*δnew+bestβ[2].*δnew.*δnew.<1+bestη[1].*z^bestη[2],1+bestβ[1].*δnew+bestβ[2].*δnew.*δnew,1+bestη[1].*z^bestη[2])
#   solveCountry(θ, γ, logT, logS, dninew, consumptionShares)

# end
# function solveCountry_ces(ρ::Float64, θ::Float64, β::Vector{Float64}, η::Vector{Float64}, γ::Array{Float64,2}, logT::Array{Float64,1}, logS::Array{Float64,1}, δnew::Float64)

#   dninew = ifelse(1+bestβ[1].*δnew+bestβ[2].*δnew.*δnew.<1+bestη[1].*z+bestη[2].*z.*z,1+bestβ[1].*δnew+bestβ[2].*δnew.*δnew,1+bestη[1].*z+bestη[2].*z.*z)
#   solveCountry_ces(ρ, θ, β, η, γ, logT, logS, dninew, consumptionShares)

# end

# Solve the equilibrium of one country given the parameters of the model.
# Note that logT, logS etc. are NOT the c_i and c_n components from the
# estimation (these include the p_i as well!).
#
# Steps: iterate on the sectoral price vector until successive iterates differ
# by less than 1e-6 in norm, compute the shares X_ni/X_n at those prices, find
# πOverL as the root of the excess-profit function on [0, 1000], and deflate
# 1+πOverL by the household price index.
#
# Returns (outputPerCapita, XniXn, 1+πOverL, priceOfHouseholds).
function solveCountry(θ::Float64, γ::Array{Float64,2}, logT::Array{Float64,1}, logS::Array{Float64,1}, dni::Array{Float64,2}, consumptionShares::Array{Float64,1})

  # TODO check formula for hh price index

  # calibration
  #σ=3.5
  α=1.0#gamma((1-σ) ./ θ + 1).^(1.0 ./ (1-σ));
  μ=1.0#σ./(σ-1)

  SDIM = 35

  # Sgrid[n,i] = exp(logS[n]) (constant along rows),
  # Tgrid[n,i] = exp(logT[i]) (constant along columns)
  Sgrid = repeat(exp.(logS),1,SDIM)
  Tgrid = repeat(exp.(logT'),SDIM,1)

  # TODO change This

  # solve for sectoral price levels by repeated substitution
  pvec = ones(Float64,SDIM,1)
  pvec_prev = zeros(Float64,SDIM,1)
  while norm(pvec - pvec_prev) >= 0.000001
    pvec_prev = pvec
    pgrid = repeat(pvec_prev',SDIM,1)
    pvec = exp.(sum(γ .* log.((α./γ) .* (Sgrid + Tgrid .* (μ .* pgrid.*dni).^(-θ)).^(-1.0 ./ θ)),dims=2))
  end

  # now calculate X_ni/X_n
  pgrid = repeat(pvec',SDIM,1)
  XniXn = γ ./ (1.0 .+ Sgrid./(Tgrid .* (μ .* pgrid .* dni).^(-θ)))

  # calculate profits; the system matrix and the profit weights do not depend
  # on the trial value, so they are built once outside the closure
  system = Matrix(1I, SDIM, SDIM) .- ((1.0 ./μ).*XniXn)'
  profitWeights = sum((1.0 - 1.0 ./ μ).*XniXn,dims=2)'
  function excessProfits(x)
    Xnvec = system\(consumptionShares.*(1+x))
    return (x .- profitWeights * Xnvec)[1]
  end

  πOverL = fzero(excessProfits,0,1000)

  # check this
  priceOfHouseholds = exp.(consumptionShares'*log.(pvec./consumptionShares))[1]

  # output per capita
  outputPerCapita = (1+πOverL)/priceOfHouseholds

  return outputPerCapita, XniXn, 1+πOverL, priceOfHouseholds
end

# # CES version of the above
# function solveCountry_ces(ρ::Float64, θ::Float64, β::Vector{Float64}, η::Vector{Float64}, γ::Array{Float64,2}, logT::Array{Float64,1}, logS::Array{Float64,1}, dni::Array{Float64,2}, consumptionShares::Array{Float64,1})

#   σ=3.5
#   α=gamma((1-σ)./θ + 1).^(1.0 ./(1-σ));
#   μ=σ./(σ-1)
#   SDIM = 35

#   # set up matrices
#   repS = repeat(exp.(logS),1,SDIM)
#   repT = repeat(exp.(logT'),SDIM,1)

#   #@show logS
#   #@show logT
#   #@show γ
#   #@show dni

#   # solve for sectoral price levels iteratively by Banach iteration
#   pvec = ones(Float64,SDIM,1)
#   pvec_old = zeros(Float64,SDIM,1)
#   while (norm(pvec - pvec_old)>=0.000001)
#     pvec_old = pvec
#     pvec = ( sum( γ.*α.^(1-ρ).*( repS + repT.*(μ.*repeat(pvec',SDIM,1).*dni).^(-θ) ).^(-(1-ρ)/θ) ,dims=2) ).^(1/(1-ρ))
#   end

#   #@show pvec

#   # now calculate X_ni/X_n
#   #XniXn = γ./(1+ repS./(repT.*(μ .* repmat(pvec',SDIM,1).*dni).^(-θ)))
#   XniXn =  ( γ .* ( α.*( repS + repT.*(μ.*repeat(pvec',SDIM,1).*dni).^(-θ) ).^(-1/θ) ).^(1-ρ) )./( repeat(sum( γ .* ( α.*( repS + repT.*(μ.*repeat(pvec',SDIM,1).*dni).^(-θ) ).^(-1/θ) ).^(1-ρ),2),1,SDIM) ) .* 1.0 ./(1+ repS./(repT.*(μ .* repmat(pvec',SDIM,1).*dni).^(-θ)))

# # calculate profits
#   function excessProfits(πOverL)
#     Xnvec = (eye(SDIM) .- ((1.0 ./μ).*XniXn)')\(consumptionShares.*(1+πOverL))
#     #excessProfits = πOverL - sum((1 - 1.0 ./(μ.*dni)).*XniXn,2)'*Xnvec
#     return (πOverL - sum((1.0 - 1.0 ./μ).*XniXn,2)' * Xnvec)[1]
#   end
#   #πOverL = nlsolve(excessProfits!,0.1,autodiff=true, method=Brent())

#   #@show mean(sum(XniXn,2))

#   πOverL = fzero(excessProfits,0,1000)
#   #@show πOverL

#   # check this
#   priceOfHouseholds = exp.(consumptionShares'*log.(pvec./consumptionShares))[1]

#   # output per capita
#   outputPerCapita = (1+πOverL)/priceOfHouseholds

#   return outputPerCapita, XniXn

# end


end
